package com.hanxiaozhang.graph;

import java.util.Comparator;
import java.util.HashSet;
import java.util.PriorityQueue;
import java.util.Set;

/**
 * 〈一句话功能简述〉<br>
 * 〈P算法生成最小生成树〉
 *  普里姆算法
 *
 * @author hanxinghua
 * @create 2021/9/27
 * @since 1.0.0
 */
public class Prim {

    /**
     * 比较器
     */
    public static class EdgeComparator implements Comparator<Edge> {

        @Override
        public int compare(Edge o1, Edge o2) {
            return o1.weight - o2.weight;
        }
    }

    /**
     *
     * @param graph
     * @return
     */
    public static Set<Edge> primMST(Graph graph) {
        // 解锁的边进入小根堆
        PriorityQueue<Edge> priorityQueue = new PriorityQueue<>(new EdgeComparator());
        // 被解锁出来的边
        HashSet<Node> nodeSet = new HashSet<>();
        // 已经考虑过的边，不要重复考虑
        HashSet<Edge> edgeSet = new HashSet<>();
        // 依次挑选的的边在result里
        Set<Edge> result = new HashSet<>();
        // 随便挑了一个点，这个for循环防止森林问题
        for (Node node : graph.nodes.values()) {
            // node 是开始点
            if (!nodeSet.contains(node)) {
                nodeSet.add(node);
                // 由一个点，解锁它所有相连的边
                for (Edge edge : node.edges) {
                    if(!edgeSet.contains(edge)){
                        edgeSet.add(edge);
                        priorityQueue.add(edge);
                    }
                }
                // 队列不为空
                while (!priorityQueue.isEmpty()) {
                    // 弹出解锁的边中，最小的边
                    Edge edge = priorityQueue.poll();
                    // 可能的一个新的点
                    Node toNode = edge.to;
                    // 不含有的时候，就是新的点
                    if (!nodeSet.contains(toNode)) {
                        nodeSet.add(toNode);
                        result.add(edge);
                        for (Edge nextEdge : toNode.edges) {
                            if (!edgeSet.contains(nextEdge)) {
                                edgeSet.add(nextEdge);
                                priorityQueue.add(nextEdge);
                            }
                        }
                    }
                }
            }
            //break;
        }
        return result;
    }


    /**
     *  请保证graph是连通图
     *  graph[i][j]表示点i到点j的距离，如果是系统最大值代表无路
     *  返回值是最小连通图的路径之和 ??
     *
     * @param graph
     * @return
     */
    public static int prim(int[][] graph) {
        int size = graph.length;
        int[] distances = new int[size];
        boolean[] visit = new boolean[size];
        visit[0] = true;
        for (int i = 0; i < size; i++) {
            distances[i] = graph[0][i];
        }
        int sum = 0;
        for (int i = 1; i < size; i++) {
            int minPath = Integer.MAX_VALUE;
            int minIndex = -1;
            for (int j = 0; j < size; j++) {
                if (!visit[j] && distances[j] < minPath) {
                    minPath = distances[j];
                    minIndex = j;
                }
            }
            if (minIndex == -1) {
                return sum;
            }
            visit[minIndex] = true;
            sum += minPath;
            for (int j = 0; j < size; j++) {
                if (!visit[j] && distances[j] > graph[minIndex][j]) {
                    distances[j] = graph[minIndex][j];
                }
            }
        }
        return sum;
    }

}
